Skip to content

[PyTorch] Batch CP attention tests in single torchrun to amortize NCC…#2965

Closed
sudhakarsingh27 wants to merge 8 commits into
NVIDIA:mainfrom
sudhakarsingh27:sudhakars/cp_test_batching_pr
Closed

[PyTorch] Batch CP attention tests in single torchrun to amortize NCC…#2965
sudhakarsingh27 wants to merge 8 commits into
NVIDIA:mainfrom
sudhakarsingh27:sudhakars/cp_test_batching_pr

Conversation

@sudhakarsingh27
Copy link
Copy Markdown
Member

@sudhakarsingh27 sudhakarsingh27 commented May 6, 2026

Design

Problem

8×H100, test_essential=True, 38 runnable CP attention configs:

  • ~554 s wall when each config runs in its own torchrun.
  • Of that, only ~216 s is the test work itself (CP fwd/bwd + assertions, ~5.7 s/config).
  • The remaining ~338 s is per-spawn overhead — Python imports + NCCL global init/teardown — paid 38 times at ~8.9 s each.

We need that overhead amortised, without changing how tests are written or how skips report.

Approach

A session-scoped fixture (_cp_batch_results) does two passes:

  1. Collect (dry-run, in-process). Walk pytest's collected items. For each item that requests _cp_batch_results, call its test function directly with a stubbed request. The body executes its inline pytest.skip(...) checks normally; if any fires, the item is dropped from the batch. Otherwise the body's final call to _run_or_fetch(...) records its kwargs in a module-level dict instead of launching a subprocess.
  2. Batch + execute. Group recorded kwargs by num_gpus_per_node, chunk into batches of CP_TEST_BATCH_SIZE (default 16), launch one torchrun per chunk. Worker (run_attention_with_cp.py) inits NCCL once, loops over configs, atomically flushes per-config results to <batch>.results.json. When pytest later runs each test for real, the body re-evaluates skips and _run_or_fetch looks up the recorded result.

How dry-run works

@pytest.fixture(scope="session")
def _cp_batch_results(request):
    items = [it for it in request.session.items
             if "_cp_batch_results" in getattr(it, "fixturenames", ())]
    _COLLECT_MODE = True
    for item in items:
        if _item_static_skip(item):
            continue
        try:
            _dry_run_item(item)
        except pytest.skip.Exception:
            pass
        except BaseException:
            pass  # surfaces in execute mode as a normal pytest error
    _COLLECT_MODE = False
    # group _COLLECTED_KWARGS by num_gpus, chunk, run torchrun batches

_dry_run_item calls the underlying function with the same parametrize values pytest would have passed:

def _dry_run_item(item):
    func = item.function
    params = dict(item.callspec.params)
    func(_DummyRequest(item.nodeid), {}, **params)

This bypasses pytest's runner entirely — no fixture setup hooks, no plugin reporters, no captured-stdout machinery.

_run_or_fetch checks a module-level _COLLECT_MODE flag:

def _run_or_fetch(request, batch_results, *, num_gpus_per_node, **worker_kwargs):
    if _COLLECT_MODE:
        _COLLECTED_KWARGS[request.node.nodeid] = dict(num_gpus=num_gpus_per_node, **worker_kwargs)
        return  # never reaches the lookup; never asserts
    entry = batch_results.get(request.node.nodeid)
    ...

In collect mode it's a recorder, in execute mode it's a result-fetcher. The test body doesn't know which mode it's in — it just calls one helper at the end.

Stubs and skip handling

Param Stub Why
request _DummyRequest(nodeid) — only request.node.nodeid _run_or_fetch only reads nodeid; test body never touches request.
_cp_batch_results {} (empty dict) _run_or_fetch returns early in collect mode, never inspects batch_results.

Inline pytest.skip("reason") raises pytest.skip.Exception. The dry-run loop catches per-item, drops the item from the batch, and moves on. In execute mode the same line raises again; pytest reports SKIPPED with the same reason.

@pytest.mark.skip and @pytest.mark.skipif(<bool_condition>) markers don't fire when calling item.function(...) directly. _item_static_skip(item) walks item.iter_markers("skip"|"skipif") and reads marker.args[0] (the condition) before the dry-run, dropping items the markers would otherwise skip.

Cost of running the body twice

For each item the body runs once during dry-run and once during execute. Skip checks are pure Python; the only non-trivial work is get_available_attention_backends, cached per nodeid via _BACKEND_CACHE so the second call is a dict hit. Measured on full test_essential=True (10272 collected items, 38 runnable): 530 cache lookups, 0.03 s total.

End-to-end pytest overhead (dry-run + collection, with torchrun stubbed): ~14 s wall, of which ~6.6 s is module-import startup, ~3.2 s is pytest per-item setup, 0.2 s is the batching infra itself. Negligible vs the GPU work it dispatches.

Performance

8×H100, test_essential=True (38 runnable configs: 34 × 2-GPU + 4 × 4-GPU). In unbatched mode each config is its own torchrun. In batched mode, configs sharing the same num_gpus_per_node are grouped into one torchrun of up to CP_TEST_BATCH_SIZE configs.

Run Torchrun spawns Wall Speedup
Unbatched 38 (one per config) ~554 s 1.0×
B=16 4 (16+16+2 @ 2GPU; 4 @ 4GPU) 274 s 2.0×
B=32 3 (32+2 @ 2GPU; 4 @ 4GPU) 248 s 2.2×
B=50 2 (34 @ 2GPU; 4 @ 4GPU) 237 s 2.3×

Where the 2× comes from

  1. 34 fewer torchrun spawns. Each saved spawn cuts ~12 s of startup (Python imports + NCCL global init/teardown) — measured directly from the wall-time delta B=50 → B=16 (37 s saved across 2 fewer spawns).
  2. ~1.2 s lower per-config work (~46 s total). Sharing NCCL global state across configs in a batch drops per-config wall from 5.7 s to 4.5 s; only the per-config CP comm-group create/destroy remains.

The spawn savings dominate.

Picking CP_TEST_BATCH_SIZE

For these 38 configs (34 × 2-GPU + 4 × 4-GPU):

  • B=16 → B=32: −27 s (one fewer 2-GPU spawn).
  • B=32 → B=50: −11 s (one more, but the merged batch is now large enough that bookkeeping eats into the savings).

16 and 32 are ballparks for this matrix — once B exceeds the largest GPU-group size (here, 34), all configs in that group already share a torchrun and further increases do nothing. With a larger config matrix (e.g. full test_essential=False ≈ 348 runnable), the same logic implies B should scale up too: pick it so the largest GPU-group has only a small number of torchrun spawns, but not so large that a single batch becomes long enough that a worker crash loses too much progress.

Knobs

Env var Effect
CP_TEST_BATCH_SIZE=N Configs per torchrun. Default 16. Set 1 to bisect.
CP_TEST_BATCH_RETRY=0 Disable singleton retry for unattributed crashes.

Adding a batched test

  1. Write the test the way you would any CP test: @pytest.mark.parametrize stack + inline pytest.skip(...) checks.
  2. Add request, _cp_batch_results to the function signature.
  3. Replace the trailing run_distributed(get_bash_arguments(...)) with _run_or_fetch(request, _cp_batch_results, num_gpus_per_node=N, ...) (kwargs become the worker's run_dpa_with_cp(**kwargs) arguments).

That's the entire wiring.

Failure semantics

Outcome What pytest sees
Inline pytest.skip(...) fires Standard SKIP (re-evaluated in execute mode and short-circuits before _run_or_fetch).
@pytest.mark.skip(if) marker fires Standard SKIP via pytest's normal path (not queued for torchrun).
Config ran, assertion failed FAIL with worker's traceback.
Assertion fired on rank > 0 only FAIL via cross-rank dist.all_reduce(ok, op=MIN).
Worker subprocess crashed before flush Each affected config retried as a singleton; real result wins, residual crashes surface as FAIL with attribution.
Dry-run itself raised Caught and ignored in the fixture; same exception fires in execute mode and pytest reports it as a normal test ERROR.

Mitigations for shared-process state

Configs in a batch share one Python process and one NCCL world, so anything that needs a clean per-test starting point is reset explicitly:

  • Per-config NCCL sub-group destruction (cp_comm_group, a2a+p2p sub-groups).
  • Reset _TRANSIENT_ENV_KEYS between configs (NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, NVTE_FP8_DPA_BWD, NVTE_DPA_FP8CS_O_in_F16, NVTE_ALLOW_NONDETERMINISTIC_ALGO).
  • torch.cuda.empty_cache(), dist.barrier() between configs.
  • RNG re-seed (1234) at start of each config.
  • copy.deepcopy(model_configs_*[model]) in the worker (THD path mutates attn_mask_type).
  • Atomic per-config flush (tmp + os.replace): a partial JSON is never visible to the reader.
  • Cross-rank dist.all_reduce(ok, op=MIN) after each config so any rank's failure flips ok to False.
  • Auto-retry crashed batch entries as singletons; disable via CP_TEST_BATCH_RETRY=0.
  • arg.split("=", 1) so kwarg values containing = (paths) survive.

Edge cases

  1. request API surface during dry-run. Only request.node.nodeid is provided. A future test that uses request.config.getoption(...) or request.getfixturevalue(...) would AttributeError during dry-run. The fixture catches BaseException so the same error fires in execute mode where pytest's real request is available.
  2. @pytest.mark.skipif(condition_evaluated_at_runtime). A skipif whose condition becomes True only at execute time would not be detected by _item_static_skip. The condition still fires correctly in execute mode; we'd just have wasted one batch slot for it.
  3. get_available_attention_backends non-determinism. If this returns different values between dry-run and execute (driver state changes), a config queued by collect could skip in execute. Harmless: _run_or_fetch is never reached, the unused batch result is garbage-collected.
  4. Pytest internals. The dry-run uses item.function, item.callspec.params, and pytest.skip.Exception. Stable in pytest 7+/8+. If they shift, _dry_run_item is a 3-line shim to update.

Validation

8×H100, test_essential=True: 38 passed / 10234 skipped / 0 unrelated failures.

Stress (no regressions): single nodeid, -k <no-match>, --collect-only, small subset, CP_TEST_BATCH_SIZE=1 all behave normally.

Files

  • tests/pytorch/attention/test_attention_with_cp.py — collect/dispatch/fetch infra, dry-run helpers, test bodies updated minimally.
  • tests/pytorch/attention/run_attention_with_cp.py_init_distributed, main() batch mode, atomic per-config flush, cross-rank aggregation, per-config group teardown, copy.deepcopy of model configs, transient env reset, split("=", 1).

Type of change

  • Code refactoring (test infrastructure; no production-code change)

Checklist

  • Contributing guidelines followed
  • Functionality complete
  • Code commented where non-obvious
  • Documentation (n/a — internal test infra)
  • No new warnings
  • Existing test suite serves as input + validation
  • Existing tests pass locally

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 6, 2026

Greptile Summary

This PR amortises the per-torchrun startup cost (Python imports + NCCL global init/teardown, ~8.9 s each) for CP attention tests by batching multiple configs into a single torchrun, achieving 2–2.3× end-to-end wall-time speedup on 8×H100 with test_essential=True.

  • A session-scoped fixture (_cp_batch_results) performs a dry-run collect phase — calling each test body directly with a stub request to gather run_dpa_with_cp kwargs or detect skip conditions — then groups configs by num_gpus, chunks them, and dispatches each chunk as one torchrun batch; results are returned in a nodeid-keyed dict that each test looks up in execute mode via _run_or_fetch.
  • run_attention_with_cp.py gains a batch-mode main() entry point: _init_distributed is called once, then _run_single_config loops over configs, resetting FP8 global state and transient env vars between configs; a cross-rank dist.all_reduce(MIN) propagates per-rank failures to rank 0, which atomically flushes per-config results via a tmp-file rename.
  • run_dpa_with_cp is made re-entrant for batch mode: it detects _owns_dist to avoid re-initialising NCCL, uses copy.deepcopy on model configs to prevent mutation across configs, and destroys only the per-config cp_comm_group (not the shared process group).

Confidence Score: 4/5

Safe to merge as test infrastructure — no production code is changed and the batch machinery is well-isolated behind the session fixture.

The NCCL communicator leak (no try/finally around cp_comm_group creation) and the loss of non-rank-0 tracebacks in batch mode — both carried forward from earlier review iterations — remain. A flaky config in a large batch can exhaust NCCL communicator table entries and corrupt subsequent configs with opaque errors, and failures on non-zero ranks give no actionable traceback to the developer.

tests/pytorch/attention/run_attention_with_cp.py — specifically the run_dpa_with_cp group-lifecycle section and the _run_single_config error-reporting path for non-rank-0 failures.

Important Files Changed

Filename Overview
tests/pytorch/attention/test_attention_with_cp.py Adds the session-scoped _cp_batch_results fixture with dry-run collect/dispatch/fetch infrastructure; refactors both test functions to use _run_or_fetch. Logic is sound but the kwargs.pop("num_gpus") call in the fixture mutates dicts in _COLLECTED_KWARGS in-place, and all Python-typed values (bool, None) pass through str() serialisation in _run_batch_once — both work given the worker's string-comparison style, but are fragile for future extension.
tests/pytorch/attention/run_attention_with_cp.py Adds batch-mode entry point in main(), _init_distributed, _run_single_config, _flush_results, per-config env/FP8 reset, cross-rank all_reduce for failure aggregation, and atomic result flush. The cp_comm_group destroy still lacks a try/finally guard so a mid-config exception leaks the NCCL communicator (flagged in previous review).

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[pytest collects items] --> B[_cp_batch_results fixture starts]
    B --> C[Set _COLLECT_MODE = True]
    C --> D{For each item in session}
    D --> E[_item_static_skip?]
    E -- Yes --> D
    E -- No --> F[_dry_run_item: call test body directly]
    F --> G{pytest.skip raised?}
    G -- Yes --> D
    G -- No --> H[_run_or_fetch records kwargs in _COLLECTED_KWARGS]
    H --> D
    D -- Done --> I[_COLLECT_MODE = False]
    I --> J[Group by num_gpus, chunk into batches]
    J --> K[For each chunk: _run_one_batch]
    K --> L[_run_batch_once: write JSON, launch torchrun]
    L --> M[Worker: _init_distributed once]
    M --> N{For each cfg in batch}
    N --> O[FP8GlobalStateManager.reset + env reset]
    O --> P[_run_single_config: run_dpa_with_cp]
    P --> Q[all_reduce ok flag across ranks]
    Q --> R[rank 0: atomic flush to results.json]
    R --> S[dist.barrier + empty_cache]
    S --> N
    N -- Done --> T[pytest reads results.json]
    T --> U[fixture returns results dict: nodeid -> ok/error]
    U --> V[Each test runs normally: _run_or_fetch looks up result]
    V --> W{ok?}
    W -- Yes --> X[PASS]
    W -- No --> Y[FAIL with error]
    W -- None --> Z[SKIP collection mismatch]
Loading

Reviews (13): Last reviewed commit: "Merge branch 'main' into sudhakars/cp_te..." | Re-trigger Greptile

Comment thread tests/pytorch/attention/run_attention_with_cp.py
@sudhakarsingh27 sudhakarsingh27 force-pushed the sudhakars/cp_test_batching_pr branch from fa189b0 to 0e9fc1f Compare May 6, 2026 23:01
Comment thread tests/pytorch/attention/test_attention_with_cp.py Outdated
@sudhakarsingh27 sudhakarsingh27 force-pushed the sudhakars/cp_test_batching_pr branch 4 times, most recently from 7802ec5 to c80df5d Compare May 7, 2026 13:57
Comment on lines +147 to +153
try:
argv = get_bash_arguments(num_gpus_per_node=num_gpus, batch_config_json=batch_path)
launch_err = None
try:
run_distributed(argv)
except AssertionError as exc:
launch_err = str(exc)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Only AssertionError caught from run_distributed

subprocess.run inside run_distributed can raise FileNotFoundError, PermissionError, or OSError for OS-level failures (missing executable, exhausted file descriptors, etc.). These propagate uncaught through _run_batch_once_run_one_batch_cp_batch_results. Because the fixture is session-scoped, one such exception causes every test that depends on _cp_batch_results to surface as a fixture ERROR rather than an individual test failure. In the original code, the same OS error would fail only the one test that triggered it. Widening the except to except (AssertionError, Exception) before reading the results file would preserve the per-batch isolation benefit.

…L init

Each parametrized CP test currently spawns its own torchrun process and
pays 5-15s of NCCL init/destroy. With ~650-800 collected tests this
adds up to 1.5-3 hours of pure setup overhead.

This change introduces a session-scoped fixture that:
  1. Calls per-test ``_prepare_*`` helpers to get either a skip reason or
     a kwargs dict for the worker.
  2. Groups runnable configs by ``num_gpus`` and chunks them into batches
     of CP_TEST_BATCH_SIZE (default 16).
  3. Launches one torchrun per chunk; the worker initialises NCCL once
     and runs all configs in the chunk inside the same world.

Per-config results are flushed to JSON after every config so a crash
mid-batch still leaves earlier results intact. Set CP_TEST_BATCH_SIZE=1
to bisect a failing batch.

Also includes a small bugfix in dot_product_attention/utils.py: the
deterministic-FA3 disable condition was firing for any head_dim_qk > 128
(including inference); restrict it to is_training and large head dims.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 sudhakarsingh27 force-pushed the sudhakars/cp_test_batching_pr branch from 1db76b7 to 6355f62 Compare May 7, 2026 14:14
Comment on lines +762 to +765
dist.destroy_process_group(cp_comm_group)
if cp_comm_type == "a2a+p2p":
for sg in cp_comm_sub_groups:
dist.destroy_process_group(sg)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 NCCL communicator leak on exception mid-function

cp_comm_group is created at line 238 (and up to 4 sub-groups for a2a+p2p at lines 248-250) but the destroy calls are at the very bottom of the function with no try/finally. Any exception that fires in between — a CUDA OOM, a comparison mismatch, a BaseException from cuDNN — causes _run_single_config to catch it and return (False, traceback), while the communicators are never cleaned up.

In batch mode the problem compounds: with 16 configs per torchrun and any flaky configs, leaked communicators accumulate across the whole batch. NCCL's internal communicator table has a fixed limit (typically 128), so a few hundred batched configs with occasional failures can exhaust it and corrupt subsequent configs with opaque "NCCL error: invalid usage" rather than surfacing the original failure. Wrapping the body after group creation in a try/finally guarantees cleanup.

@sudhakarsingh27
Copy link
Copy Markdown
Member Author

Tested this PR's batched CP changes on B200 (sm_103) and H100 (sm_90). The H100 run passed because most CP variants gate on sm_103 and skip on sm_90 — only 41 tests actually executed. The B200 run surfaced 143 failures that all share a single root cause — they're not independent bugs.

Failure pattern (every one of the 143 failures has this traceback referencing the same ProcessGroup instance):

torch.distributed.DistBackendError: NCCL communicator was aborted on rank 0.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "tests/pytorch/attention/run_attention_with_cp.py", line 800, in _run_single_config
    run_dpa_with_cp(**kwargs)
  File "tests/pytorch/attention/run_attention_with_cp.py", line 366, in run_dpa_with_cp
    with fp8_context:
  File "/usr/lib/python3.12/contextlib.py", line 144, in __exit__
    next(self.gen)
  File "transformer_engine/pytorch/quantization.py", line 905, in autocast
    FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph)
  File "transformer_engine/pytorch/quantization.py", line 649, in autocast_exit
    cls.reduce_and_update_fp8_tensors(forward=True)
  File "transformer_engine/pytorch/quantization.py", line 554, in reduce_and_update_fp8_tensors
    cls.reduce_tensor_across_group_op_max(contiguous_amax, group)
  File "transformer_engine/pytorch/quantization.py", line 518, in reduce_tensor_across_group_op_max
    torch.distributed.all_reduce(...)

ValueError: Process group ... is not initialized in the world group map.
    Please initialize the group first.

Cascade mechanism:

  1. The _cp_batch_results fixture pre-runs CP tests in batches over NCCL.
  2. One batch's NCCL communicator aborts on rank 0 (the original error isn't surfaced as the first failure).
  3. The world process group enters "not initialized" state and doesn't recover.
  4. Subsequent tests don't find their entry in _cp_batch_results and fall through to per-test execution via run_dpa_with_cp(...).
  5. The per-test path completes its computation, but FP8 cleanup (autocast_exitreduce_and_update_fp8_tensorsall_reduce) needs the world group → fails with "not initialized" → 143 identical AssertionErrors.

Performance observation: the CP session ran ~9s/test (1886s / 206 actually-executed tests). That's torchrun-startup-plus-a-bit per test — batching gave ~0 speedup once the first batch crashed and the fall-through path took over.

Suggested fixes:

  • When a batch's NCCL group aborts, don't let remaining tests fall through to per-test execution with a dead group. Either reset the world process group between batches, or mark all remaining tests in the crashed batch as ERROR upfront.
  • Surface the original NCCL abort as the first failure (with the failing test variant) instead of letting it manifest as 143 cleanup-failure cascades.

Results summary:

Hardware Passed Failed Skipped
H100 (sm_90) 41 0 10231
B200 (sm_103) 63 143 10066

Widen the except in _run_batch_once from AssertionError to Exception
so OS-level failures from subprocess.run (FileNotFoundError when the
worker script is missing, PermissionError, OSError when fds are
exhausted, etc.) are attributed to the batch they came from instead
of escaping the session-scoped _cp_batch_results fixture and
ERROR-ing every CP test in the run.

Addresses Greptile P1 review comment on PR 2965.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
FP8GlobalStateManager retains quantizer registrations that reference
destroyed NCCL process groups, causing cascade failures when multiple
FP8 configs run in a single torchrun batch.  Reset the singleton
between configs to prevent this.

get_available_attention_backends is stateful — calling it during the
dry-run collect phase can produce different results than during the
execute phase, causing "skip divergence" where the batch collects
configs that should have been skipped.  Cache backend availability
per test node ID so the decision is consistent across phases.

Also: pass MASTER_PORT through to torchrun so parallel pytest
invocations on different GPU sets don't collide, and add [CP-BATCH]
progress logging to the batch infrastructure.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Restore run_dpa_with_cp as self-contained: detect whether dist is
already initialized and only init/destroy the global process group
when called standalone (legacy single-config mode). In batch mode
the function reuses the caller's process group and only tears down
per-config CP comm groups.

Extract _cached_backend_check helper so the backend-availability
cache is not scattered into both test bodies. Trim verbose docstrings
and inline comments down to single-line summaries.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 sudhakarsingh27 force-pushed the sudhakars/cp_test_batching_pr branch from 97fcd4c to e591e02 Compare May 8, 2026 21:02
Comment on lines +838 to +839
if not ok_aggregate and ok and err is None:
err = "Failed on a non-zero rank (see subprocess stderr for traceback)"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Non-rank-0 failure traceback is swallowed and the error guidance is wrong

When a non-zero rank fails inside _run_single_config, its traceback is captured in that rank's local err variable but never transmitted to rank 0. The all_reduce propagates the ok=0 flag correctly, but rank 0 only records "Failed on a non-zero rank (see subprocess stderr for traceback)". That guidance is wrong: because _run_single_config catches the exception on rank 1, rank 1 exits cleanly and torchrun exits with code 0 — there is no traceback in subprocess stderr. A developer investigating the failure would find nothing there.

This is a regression from the original non-batched flow where rank 1's uncaught exception printed directly to torchrun's stderr and was captured by run_distributed. A minimal fix is to have the failing rank(s) print their traceback to sys.stderr before returning from _run_single_config, so it appears in torchrun's captured output even when the process exits cleanly.

sudhakarsingh27 added a commit to sudhakarsingh27/TransformerEngine that referenced this pull request May 9, 2026
Port CP test batching from sudhakars/cp_test_batching_pr (PR NVIDIA#2965).
Groups parametrized configs into batches of CP_TEST_BATCH_SIZE (default
16) and runs each batch in a single torchrun invocation, amortizing the
~9s NCCL init overhead across configs instead of paying it per test.

This is a temporary commit to validate batching under CI on the
flash_attn_pad_bw_seqs branch — intended to be reverted after the run.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 sudhakarsingh27 force-pushed the sudhakars/cp_test_batching_pr branch from 51d0ba5 to 6b255e1 Compare May 12, 2026 23:56
sudhakarsingh27 added a commit to sudhakarsingh27/TransformerEngine that referenced this pull request May 14, 2026
Three changes that bring the pool's failure semantics on par with the
per-batch torchrun approach in PR NVIDIA#2965 and remove a couple of footguns:

1. Capture pool-worker stderr into a ring buffer and attach the tail to
   crash-path AssertionErrors. Equivalent in spirit to PR NVIDIA#2965's
   run_distributed() — CI JUnit XML now shows the actual cause (NCCL
   error, Python traceback, OOM) inline with the failing test, instead
   of just "pool worker died mid-request" / "timed out". A daemon
   drainer thread reads stderr line-by-line into a deque(maxlen=200)
   and also echoes to sys.stderr so pytest's per-test capture still
   gets every line. Maximum buffered footprint ~40 KB.

2. Tighten POOL_SUBMIT_TIMEOUT_SEC default 600 -> 90. On H100 the
   slowest observed per-case wall is ~15 s (p99 also 15 s, p50 ~5 s).
   90 s gives ~6x headroom over the worst observed case while still
   detecting a genuine hang within ~1.5 min instead of ~10 min. Env
   var still overrides for slower machines or expanded test matrices.

3. Optional per-case wall-time logging (NVTE_CP_POOL_TIMING=1) prints
   "[POOL-TIMING] case_idx=N world_size=W wall_s=X.XXX ok=B" to stderr
   on rank 0 only. Grep-friendly; lets future tuning recalibrate the
   timeout against the observed distribution. Off by default so normal
   runs stay quiet.

Validated: 38 passed / 0 failed in 248 s on H100, test_essential=True,
with no perf regression vs the un-patched 256 s.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
sudhakarsingh27 added a commit that referenced this pull request May 21, 2026
* Batch CP attention tests via a persistent NCCL pool

The existing test path spawns one torchrun per parametrized case, paying
NCCL init + CUDA context + Python startup on every call. With ~hundreds
of cases the launch overhead dominates wall time and was a primary driver
of the L3 timeout that prior batching PRs worked around.

This change replaces the per-case subprocess with one long-lived
torchrun per (world_size). NCCL is initialized once at session start and
reused across cases. Pytest sends one JSON request per case over rank-0
stdin; the worker dispatches to run_dpa_with_cp(**kwargs), gathers
(ok, error) from every rank, and writes one JSON response on rank-0
stdout.

run_attention_with_cp.py is left almost untouched; a new
NVTE_CP_POOL_PG=1 env var gates the dist.init_process_group() and
dist.destroy_process_group() calls so the function reuses the pool's
main PG instead of creating its own. The per-case cp_comm_group (and
a2a+p2p sub-groups) are explicitly destroyed at function exit to
prevent communicator leakage across cases.

The PoolWorker class adds two pieces of error recovery that the prior
subprocess-per-case design got for free: a select-based per-call
timeout (default 600s, NVTE_CP_POOL_TIMEOUT_SEC) and auto-respawn on
worker death or timeout. A test-level exception is reported as an
AssertionError and the pool keeps running for the next case.

Two pool sizes are needed because cp_comm_type='a2a+p2p' requires
world_size=4 and the others use world_size=2; you can't resize an
active PG. Pools are spawned lazily so a 2-GPU-only run never pays the
4-GPU init.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Reset FP8 state and barrier between pool cases

Two resilience fixes carried over from the existing batching PR
(sudhakars/cp_test_batching_pr) without which the pool will
cascade-fail FP8 tests and silently propagate NCCL desync.

1. FP8GlobalStateManager.reset() between cases. FP8 quantizer state
   (recipe handles, autocast counters) lives in module-level globals.
   Reusing one Python process across cases otherwise carries that state
   forward. The prior batching PR landed an explicit fix for the same
   issue ("Fix FP8 cascade failures") after observing real test
   failures from this.

2. dist.barrier() after each case. If one rank's case errored before
   its last collective, the others can be stuck waiting on a comm that
   will never complete. The barrier here surfaces that immediately as
   a timeout in this case rather than letting the corruption leak into
   the next case's collectives.

Also pops the transient NVTE_* env vars run_dpa_with_cp sets at the
top of each call. run_dpa_with_cp already sets them unconditionally so
this is defensive, but cheap insurance against future variants that
might not.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Deep-copy ModelConfig in run_dpa_with_cp

The model_configs_{flash,fused}_attn dicts are module-level and shared
across pool cases. The THD branch below rewrites config.attn_mask_type
in place (causal -> padding_causal, no_mask -> padding). With the
persistent-pool runner, the next case looking up the same model key
gets the mutated config and fails the "causal or no_mask only" assert.

Caught at benchmark time on cp_2_0 + thd, identical to the cascade the
existing batching PR (sudhakars/cp_test_batching_pr) hit and fixed the
same way in commit 6355f62.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Skip deterministic configs incompatible with FusedAttention

Mirrors the two pre-emptive skips on the PR-batching branch:

* non-vanilla softmax with FusedAttention is not deterministic
* post_scale_bias with requires_grad is not deterministic

Without these skips, the corresponding configs propagate into the pool
worker under NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 and fail inside
run_dpa_with_cp instead of being marked SKIPPED.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Reseed RNG between pool cases; reset before, not after

The pool worker reused RNG state across cases, which produced
small numerical drift on some non-FP8 fused-attention configs
(cp_1_0 + thd/p2p, cp_1_0 + sbhd/all_gather) compared to the
single-shot worker. Matches the per-case startup of the single-shot
worker: torch.manual_seed(1234) + torch.cuda.manual_seed(1234) at
the start of every case, alongside the existing FP8 / env / cache
resets.

Moved the reset call from the post-case finally block to the start
of _run_one so the first case is also seeded consistently with
subsequent cases. Otherwise the first case would inherit the
process-default RNG and only the second-and-later cases would be
deterministic.

Validated locally: 38 passed, 0 failed (was 36 passed, 2 failed).

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Robustify pool: capture worker stderr, tighten timeout, add timing knob

Three changes that bring the pool's failure semantics on par with the
per-batch torchrun approach in PR #2965 and remove a couple of footguns:

1. Capture pool-worker stderr into a ring buffer and attach the tail to
   crash-path AssertionErrors. Equivalent in spirit to PR #2965's
   run_distributed() — CI JUnit XML now shows the actual cause (NCCL
   error, Python traceback, OOM) inline with the failing test, instead
   of just "pool worker died mid-request" / "timed out". A daemon
   drainer thread reads stderr line-by-line into a deque(maxlen=200)
   and also echoes to sys.stderr so pytest's per-test capture still
   gets every line. Maximum buffered footprint ~40 KB.

2. Tighten POOL_SUBMIT_TIMEOUT_SEC default 600 -> 90. On H100 the
   slowest observed per-case wall is ~15 s (p99 also 15 s, p50 ~5 s).
   90 s gives ~6x headroom over the worst observed case while still
   detecting a genuine hang within ~1.5 min instead of ~10 min. Env
   var still overrides for slower machines or expanded test matrices.

3. Optional per-case wall-time logging (NVTE_CP_POOL_TIMING=1) prints
   "[POOL-TIMING] case_idx=N world_size=W wall_s=X.XXX ok=B" to stderr
   on rank 0 only. Grep-friendly; lets future tuning recalibrate the
   timeout against the observed distribution. Off by default so normal
   runs stay quiet.

Validated: 38 passed / 0 failed in 248 s on H100, test_essential=True,
with no perf regression vs the un-patched 256 s.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address PR review: NCCL leak, stdout protocol, Windows note

Three fixes responding to #2993
review comments:

P1: NCCL communicator leak on exception (run_attention_with_cp.py)

run_dpa_with_cp() created cp_comm_group (and optionally cp_comm_sub_groups)
near the top, but the destroy_process_group() calls ran only on the
success path at the end of the function. Any exception in between
(tensor assertion, OOM, NCCL error) skipped the cleanup, leaking
communicators in pool mode. Long sessions with repeated failures
could exhaust NCCL internal tracking.

Wrap the test work in try/finally so the destroy logic always runs.
Initialise cp_comm_sub_groups = [] unconditionally so the finally
block is safe even when cp_comm_type != "a2a+p2p" (or when an assert
fires before the populate loop). Each destroy is itself try/except so
a destroy failure on one group doesn't leak the others.

P2: stdout protocol can be corrupted by interleaved chatter

torchrun and ranks 1..N share rank 0's stdout fd. Any non-rank-0
print, NCCL debug line, or torchrun status output interleaves with
the JSON response and breaks json.loads, killing the pool with a
misleading "json decode error".

Prefix every response with "[CP_POOL_RESP] " in run_attention_with_cp_pool.py
and have PoolWorker.submit() scan stdout for sentinel-prefixed lines,
echoing non-protocol lines to stderr for visibility. Bounded scan
(MAX_NOISE_LINES=1000) so a chatty worker can't stall the parent.

P2 (doc): select.select on a pipe fd is Linux/macOS only

Added a short comment noting Windows portability. CP attention tests
run on Linux GPU hosts; this is a documentation issue, not a real bug.

Validated: 38 passed / 0 failed in 270 s on H100, test_essential=True
(was 248 s pre-P2 — the +22 s is the new sentinel-scan loop's per-line
overhead at ~600 ms/case, within noise).

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [PyTorch] Fix stream race on max_logit_per_step in all-gather CP forward

In AttnFuncWithCPAndKVAllGather.forward, max_logit_per_step[i] is
written inside `with torch.cuda.stream(flash_attn_streams[i])`. For
i=1, flash_attn_streams[1] is cp_stream — i.e. *not* the default
stream. Later, at loop iteration i=2, the code reads
max_logit_per_step[1] via `torch.maximum(max_logit, max_logit_per_step[i-1])`
which runs on the default stream. Without an explicit wait_stream,
this is a read-after-write race across streams. The post-loop
`current_stream().wait_stream(cp_stream)` is too late — the race has
already fired.

The race is latent: outcome depends on stream scheduling. In a
fresh-process subprocess (one-torchrun-per-test path), streams are
cleanly initialised and timing happens to put the write before the
read. In a long-running persistent-worker process — exposed by
PR #2993's pool design — prior workloads shape stream state
differently, the read can fire before the write completes, and
max_logit ends up with stale values in some heads (~0.3 abs diff,
3/12 elements wrong on the H100 matrix).

Fix: insert `current_stream().wait_stream(flash_attn_streams[i-1])`
before the torch.maximum read. No-op when the streams are identical
(i=1 case, where flash_attn_streams[0] is current_stream), only
fires when reading from cp_stream (i=2 case).

Validated: 8xH100, test_essential=False, 348 passed / 0 failed in
27m 10s (was 323 passed + 5 failed at this commit's parent, all 5
failing on cp_comm_type=all_gather with mismatched max_logit).
The failing configs (all_gather + cp_1_0/cp_1_1 + bshd or fp16) now
pass under the pool — confirming the race was the sole root cause.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Address PR review (R2): drop dead code in pool worker and PoolWorker

Line-level cleanups from the second reviewer pass on PR #2993. Each item
is dead/redundant; none changes behaviour. Full-matrix test_essential=False
on 8xH100 still passes 348/0 in 26m 23s after these.

run_attention_with_cp_pool.py:
- Drop _TRANSIENT_ENV_KEYS tuple + pop loop. run_dpa_with_cp already
  re-sets NVTE_FUSED_ATTN/NVTE_FLASH_ATTN unconditionally at the top
  and pops the FP8 ones itself. The pop loop was defensive against a
  hypothetical "future caller that doesn't re-set them" that doesn't
  exist.
- Drop gc.collect() after torch.cuda.empty_cache(). The cases create
  no Python reference cycles between iterations and empty_cache only
  frees CUDA blocks PyTorch already considers free; the combination
  was no-op here.
- Drop dist.barrier() after dist.gather_object(). gather_object is
  itself a collective synchronization point — if every rank reaches
  it, none is ahead. The "surface a wedged communicator here" comment
  was wishful: a wedged communicator would already wedge the gather.

test_attention_with_cp.py (PoolWorker):
- Drop _MAX_NOISE_LINES = 1000 + the scanned counter + the
  unreachable post-loop "1000+ lines" branch. select()'s deadline
  already bounds the loop; the line-count cap was redundant and
  the over-limit branch was unreachable in practice.
- Inline _stderr_tail() into _diag(). Single caller, single use.
- Drop the _stderr_thread attribute. The drainer is daemon and
  self-terminates when the pipe closes; we never read the field
  anywhere, so initialising and nulling it was bookkeeping for no
  reason.
- Drop the dead assert in submit() — _ensure_alive() on the prior
  line already guarantees proc/stdin/stdout exist.

Deferred to a follow-up:
- L8 (drop try/except around dist.destroy_process_group). Real
  semantic change: hides errors that occur when a previous test
  wedged the communicator. Worth doing but needs its own validation.
- R1 medium items M1 (module-level flag vs NVTE_CP_POOL_PG env var),
  M2 (redirect rank>0 stdout vs sentinel scan), M3 (explicit
  CUDA_VISIBLE_DEVICES per pool). Same reasoning — separate PRs.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Address PR review (items 2+3): reuse CP groups across pool cases

world_size and the rank set don't change for the lifetime of one pool, so
recreating the world group and a2a+p2p sub-groups per case wastes ~50-100 ms
of NCCL setup each. Pre-create them once in the pool worker (new helper
_create_cp_comm_groups), stash on the run_attention_with_cp module via
module-level _pool_cp_comm_group / _pool_cp_comm_sub_groups pointers, and
reuse them from run_dpa_with_cp in pool mode. Pool teardown destroys them
once at shutdown.

Also move per-case dist.new_group() calls inside the try/finally in
run_dpa_with_cp: a failure mid-loop in the a2a+p2p sub_group population
otherwise leaks every communicator created before the failure. The finally
now only destroys groups we created locally (cp_comm_group / sub_groups
populated in the else-branch), leaving pool-owned groups alone for reuse.

cyanguwa's review feedback on PR #2993.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Flatten try/finally wrap in run_dpa_with_cp

The Round-1 P1 NCCL-communicator-leak fix (e162a9e) wrapped the
~540-line body of run_dpa_with_cp in try/finally. The wrap itself
was tiny but it re-indented every line of the body by one level,
inflating the PR diff of run_attention_with_cp.py to ~1000 lines
against origin/main.

Items 2+3 (d15bfce) since made the wrap unnecessary:
  - In pool mode, cp_comm_group and cp_comm_sub_groups are owned by
    the pool worker (which destroys them once at pool shutdown).
    run_dpa_with_cp neither creates nor destroys them, so an
    in-body exception can't leak communicators.
  - In single-shot mode, groups are still created locally, but the
    subprocess exits at function return; NCCL releases everything
    at process teardown, so a stray exception leaks communicators
    only for the milliseconds before the process dies — a bounded
    one-off cost, not the unbounded accumulation that Round-1
    flagged for pool mode.

Removing the wrap drops the run_attention_with_cp.py diff against
origin/main from ~1000 lines to ~120 lines without changing
observable behaviour. Smoke-tested: 4 representative cases pass.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Set test_essential=True to match shipping default

Round-3 review (greptile, discussion_r3250016711) flagged that the
working tree had test_essential=False — i.e. the full ~328-config
matrix instead of the ~38-config essential subset that the rest of the
CI matrix expects. Flipping back to True so CI doesn't regress baseline
on the known H1-style cascade configs that only appear in the full
matrix.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Retry once on pool-infrastructure failures with stderr-logged flake trace

The pool worker subprocess can die mid-case due to async NCCL aborts or
flaky 4-GPU collective state that doesn't reproduce on a fresh pool.
Without retry, these manifest as one-off CI failures attributable to
infrastructure, not the PR's content.

Add a single-attempt retry around PoolWorker.submit() that fires only
on infrastructure failure modes (pool-worker-died, timeout,
broken-pipe-pre-send). Test-assertion failures from the worker
(resp["error"]) carry full per-rank tracebacks and propagate without
retry — so a real bug still surfaces as FAILED.

Visibility: every retry attempt writes a [POOL-RETRY] line to stderr.
pytest captures per-test stderr and writes it into JUnit
<testcase>/<system-err>. A flaky test will appear as PASSED in the
case row but with a [POOL-RETRY] line in <system-err> — visible to
the reviewer, and queryable by CI dashboards looking for flake
patterns (e.g. "same test_id retries across multiple CI runs").

If both attempts die, a [POOL-RETRY-FAIL] line is also logged with
the first error's headline, then the second attempt's full traceback
propagates as the test failure.

Smoke-tested: 3 representative cases (p2p, a2a flash; p2p fused)
still PASS in 19 s.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [PyTorch] Pool: redirect non-rank-0 stdout to /dev/null; drop sentinel

Replaces the [CP_POOL_RESP] sentinel-prefix protocol with a stronger
fix at the source: on rank>0, close stdout at the fd level via dup2
to /dev/null at worker startup. Catches both Python `print` writes
and C-level (NCCL, libc, etc.) writes that the sentinel could only
mitigate by scanning + skipping non-protocol lines.

With non-rank-0 stdout silenced, rank 0's JSON line is the only
thing that reaches the parent's pipe, so PoolWorker._submit_once
collapses from a sentinel-scanning while loop to a single
select + readline + json.loads.

Closes follow-up M2 from the PR description; addresses greptile's
review comment on stdout pollution. Validated on 8xH100 with the
test_essential=True flash-attn pool path (9 passed / 55 skipped /
0 failed in 56s; no JSONDecodeError, no protocol corruption).

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Address PR review (R3): backend-cache, pool isolation, group-kill, decode-safety

- Invalidate DotProductAttention._attention_backends between pool cases so
  per-case NVTE_FLASH_ATTN/NVTE_FUSED_ATTN toggles take effect instead of
  reusing the previous case's resolved backend.
- torch.cuda.empty_cache() after each case so a 2-GPU pool doesn't squat on
  GPUs that an overlapping 4-GPU pool needs.
- PoolWorker subprocess uses start_new_session=True; _kill() uses killpg on
  the whole process group so torchrun's rank workers don't survive as
  orphans holding CUDA/NCCL state.
- On a failed worker response, kill the pool before raising so half-aborted
  CUDA/NCCL/FP8 state from a failed case doesn't leak into the next.
- Guard json.loads with a try/except + diagnostic so any rank-0 stdout
  pollution surfaces as a clear test failure rather than a silent protocol
  desync.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@sudhakarsingh27
Copy link
Copy Markdown
Member Author

Closing because #2993 superseded this design and was merged.

KshitijLakhani pushed a commit that referenced this pull request May 27, 2026
* Batch CP attention tests via a persistent NCCL pool

The existing test path spawns one torchrun per parametrized case, paying
NCCL init + CUDA context + Python startup on every call. With ~hundreds
of cases the launch overhead dominates wall time and was a primary driver
of the L3 timeout that prior batching PRs worked around.

This change replaces the per-case subprocess with one long-lived
torchrun per (world_size). NCCL is initialized once at session start and
reused across cases. Pytest sends one JSON request per case over rank-0
stdin; the worker dispatches to run_dpa_with_cp(**kwargs), gathers
(ok, error) from every rank, and writes one JSON response on rank-0
stdout.

run_attention_with_cp.py is left almost untouched; a new
NVTE_CP_POOL_PG=1 env var gates the dist.init_process_group() and
dist.destroy_process_group() calls so the function reuses the pool's
main PG instead of creating its own. The per-case cp_comm_group (and
a2a+p2p sub-groups) are explicitly destroyed at function exit to
prevent communicator leakage across cases.

The PoolWorker class adds two pieces of error recovery that the prior
subprocess-per-case design got for free: a select-based per-call
timeout (default 600s, NVTE_CP_POOL_TIMEOUT_SEC) and auto-respawn on
worker death or timeout. A test-level exception is reported as an
AssertionError and the pool keeps running for the next case.

Two pool sizes are needed because cp_comm_type='a2a+p2p' requires
world_size=4 and the others use world_size=2; you can't resize an
active PG. Pools are spawned lazily so a 2-GPU-only run never pays the
4-GPU init.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Reset FP8 state and barrier between pool cases

Two resilience fixes carried over from the existing batching PR
(sudhakars/cp_test_batching_pr) without which the pool will
cascade-fail FP8 tests and silently propagate NCCL desync.

1. FP8GlobalStateManager.reset() between cases. FP8 quantizer state
   (recipe handles, autocast counters) lives in module-level globals.
   Reusing one Python process across cases otherwise carries that state
   forward. The prior batching PR landed an explicit fix for the same
   issue ("Fix FP8 cascade failures") after observing real test
   failures from this.

2. dist.barrier() after each case. If one rank's case errored before
   its last collective, the others can be stuck waiting on a comm that
   will never complete. The barrier here surfaces that immediately as
   a timeout in this case rather than letting the corruption leak into
   the next case's collectives.

Also pops the transient NVTE_* env vars run_dpa_with_cp sets at the
top of each call. run_dpa_with_cp already sets them unconditionally so
this is defensive, but cheap insurance against future variants that
might not.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Deep-copy ModelConfig in run_dpa_with_cp

The model_configs_{flash,fused}_attn dicts are module-level and shared
across pool cases. The THD branch below rewrites config.attn_mask_type
in place (causal -> padding_causal, no_mask -> padding). With the
persistent-pool runner, the next case looking up the same model key
gets the mutated config and fails the "causal or no_mask only" assert.

Caught at benchmark time on cp_2_0 + thd, identical to the cascade the
existing batching PR (sudhakars/cp_test_batching_pr) hit and fixed the
same way in commit 6355f62.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Skip deterministic configs incompatible with FusedAttention

Mirrors the two pre-emptive skips on the PR-batching branch:

* non-vanilla softmax with FusedAttention is not deterministic
* post_scale_bias with requires_grad is not deterministic

Without these skips, the corresponding configs propagate into the pool
worker under NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 and fail inside
run_dpa_with_cp instead of being marked SKIPPED.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Reseed RNG between pool cases; reset before, not after

The pool worker reused RNG state across cases, which produced
small numerical drift on some non-FP8 fused-attention configs
(cp_1_0 + thd/p2p, cp_1_0 + sbhd/all_gather) compared to the
single-shot worker. Matches the per-case startup of the single-shot
worker: torch.manual_seed(1234) + torch.cuda.manual_seed(1234) at
the start of every case, alongside the existing FP8 / env / cache
resets.

Moved the reset call from the post-case finally block to the start
of _run_one so the first case is also seeded consistently with
subsequent cases. Otherwise the first case would inherit the
process-default RNG and only the second-and-later cases would be
deterministic.

Validated locally: 38 passed, 0 failed (was 36 passed, 2 failed).

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Robustify pool: capture worker stderr, tighten timeout, add timing knob

Three changes that bring the pool's failure semantics on par with the
per-batch torchrun approach in PR #2965 and remove a couple of footguns:

1. Capture pool-worker stderr into a ring buffer and attach the tail to
   crash-path AssertionErrors. Equivalent in spirit to PR #2965's
   run_distributed() — CI JUnit XML now shows the actual cause (NCCL
   error, Python traceback, OOM) inline with the failing test, instead
   of just "pool worker died mid-request" / "timed out". A daemon
   drainer thread reads stderr line-by-line into a deque(maxlen=200)
   and also echoes to sys.stderr so pytest's per-test capture still
   gets every line. Maximum buffered footprint ~40 KB.

2. Tighten POOL_SUBMIT_TIMEOUT_SEC default 600 -> 90. On H100 the
   slowest observed per-case wall is ~15 s (p99 also 15 s, p50 ~5 s).
   90 s gives ~6x headroom over the worst observed case while still
   detecting a genuine hang within ~1.5 min instead of ~10 min. Env
   var still overrides for slower machines or expanded test matrices.

3. Optional per-case wall-time logging (NVTE_CP_POOL_TIMING=1) prints
   "[POOL-TIMING] case_idx=N world_size=W wall_s=X.XXX ok=B" to stderr
   on rank 0 only. Grep-friendly; lets future tuning recalibrate the
   timeout against the observed distribution. Off by default so normal
   runs stay quiet.

Validated: 38 passed / 0 failed in 248 s on H100, test_essential=True,
with no perf regression vs the un-patched 256 s.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address PR review: NCCL leak, stdout protocol, Windows note

Three fixes responding to #2993
review comments:

P1: NCCL communicator leak on exception (run_attention_with_cp.py)

run_dpa_with_cp() created cp_comm_group (and optionally cp_comm_sub_groups)
near the top, but the destroy_process_group() calls ran only on the
success path at the end of the function. Any exception in between
(tensor assertion, OOM, NCCL error) skipped the cleanup, leaking
communicators in pool mode. Long sessions with repeated failures
could exhaust NCCL internal tracking.

Wrap the test work in try/finally so the destroy logic always runs.
Initialise cp_comm_sub_groups = [] unconditionally so the finally
block is safe even when cp_comm_type != "a2a+p2p" (or when an assert
fires before the populate loop). Each destroy is itself try/except so
a destroy failure on one group doesn't leak the others.

P2: stdout protocol can be corrupted by interleaved chatter

torchrun and ranks 1..N share rank 0's stdout fd. Any non-rank-0
print, NCCL debug line, or torchrun status output interleaves with
the JSON response and breaks json.loads, killing the pool with a
misleading "json decode error".

Prefix every response with "[CP_POOL_RESP] " in run_attention_with_cp_pool.py
and have PoolWorker.submit() scan stdout for sentinel-prefixed lines,
echoing non-protocol lines to stderr for visibility. Bounded scan
(MAX_NOISE_LINES=1000) so a chatty worker can't stall the parent.

P2 (doc): select.select on a pipe fd is Linux/macOS only

Added a short comment noting Windows portability. CP attention tests
run on Linux GPU hosts; this is a documentation issue, not a real bug.

Validated: 38 passed / 0 failed in 270 s on H100, test_essential=True
(was 248 s pre-P2 — the +22 s is the new sentinel-scan loop's per-line
overhead at ~600 ms/case, within noise).

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [PyTorch] Fix stream race on max_logit_per_step in all-gather CP forward

In AttnFuncWithCPAndKVAllGather.forward, max_logit_per_step[i] is
written inside `with torch.cuda.stream(flash_attn_streams[i])`. For
i=1, flash_attn_streams[1] is cp_stream — i.e. *not* the default
stream. Later, at loop iteration i=2, the code reads
max_logit_per_step[1] via `torch.maximum(max_logit, max_logit_per_step[i-1])`
which runs on the default stream. Without an explicit wait_stream,
this is a read-after-write race across streams. The post-loop
`current_stream().wait_stream(cp_stream)` is too late — the race has
already fired.

The race is latent: outcome depends on stream scheduling. In a
fresh-process subprocess (one-torchrun-per-test path), streams are
cleanly initialised and timing happens to put the write before the
read. In a long-running persistent-worker process — exposed by
PR #2993's pool design — prior workloads shape stream state
differently, the read can fire before the write completes, and
max_logit ends up with stale values in some heads (~0.3 abs diff,
3/12 elements wrong on the H100 matrix).

Fix: insert `current_stream().wait_stream(flash_attn_streams[i-1])`
before the torch.maximum read. No-op when the streams are identical
(i=1 case, where flash_attn_streams[0] is current_stream), only
fires when reading from cp_stream (i=2 case).

Validated: 8xH100, test_essential=False, 348 passed / 0 failed in
27m 10s (was 323 passed + 5 failed at this commit's parent, all 5
failing on cp_comm_type=all_gather with mismatched max_logit).
The failing configs (all_gather + cp_1_0/cp_1_1 + bshd or fp16) now
pass under the pool — confirming the race was the sole root cause.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Address PR review (R2): drop dead code in pool worker and PoolWorker

Line-level cleanups from the second reviewer pass on PR #2993. Each item
is dead/redundant; none changes behaviour. Full-matrix test_essential=False
on 8xH100 still passes 348/0 in 26m 23s after these.

run_attention_with_cp_pool.py:
- Drop _TRANSIENT_ENV_KEYS tuple + pop loop. run_dpa_with_cp already
  re-sets NVTE_FUSED_ATTN/NVTE_FLASH_ATTN unconditionally at the top
  and pops the FP8 ones itself. The pop loop was defensive against a
  hypothetical "future caller that doesn't re-set them" that doesn't
  exist.
- Drop gc.collect() after torch.cuda.empty_cache(). The cases create
  no Python reference cycles between iterations and empty_cache only
  frees CUDA blocks PyTorch already considers free; the combination
  was no-op here.
- Drop dist.barrier() after dist.gather_object(). gather_object is
  itself a collective synchronization point — if every rank reaches
  it, none is ahead. The "surface a wedged communicator here" comment
  was wishful: a wedged communicator would already wedge the gather.

test_attention_with_cp.py (PoolWorker):
- Drop _MAX_NOISE_LINES = 1000 + the scanned counter + the
  unreachable post-loop "1000+ lines" branch. select()'s deadline
  already bounds the loop; the line-count cap was redundant and
  the over-limit branch was unreachable in practice.
- Inline _stderr_tail() into _diag(). Single caller, single use.
- Drop the _stderr_thread attribute. The drainer is daemon and
  self-terminates when the pipe closes; we never read the field
  anywhere, so initialising and nulling it was bookkeeping for no
  reason.
- Drop the dead assert in submit() — _ensure_alive() on the prior
  line already guarantees proc/stdin/stdout exist.

Deferred to a follow-up:
- L8 (drop try/except around dist.destroy_process_group). Real
  semantic change: hides errors that occur when a previous test
  wedged the communicator. Worth doing but needs its own validation.
- R1 medium items M1 (module-level flag vs NVTE_CP_POOL_PG env var),
  M2 (redirect rank>0 stdout vs sentinel scan), M3 (explicit
  CUDA_VISIBLE_DEVICES per pool). Same reasoning — separate PRs.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Address PR review (items 2+3): reuse CP groups across pool cases

world_size and the rank set don't change for the lifetime of one pool, so
recreating the world group and a2a+p2p sub-groups per case wastes ~50-100 ms
of NCCL setup each. Pre-create them once in the pool worker (new helper
_create_cp_comm_groups), stash on the run_attention_with_cp module via
module-level _pool_cp_comm_group / _pool_cp_comm_sub_groups pointers, and
reuse them from run_dpa_with_cp in pool mode. Pool teardown destroys them
once at shutdown.

Also move per-case dist.new_group() calls inside the try/finally in
run_dpa_with_cp: a failure mid-loop in the a2a+p2p sub_group population
otherwise leaks every communicator created before the failure. The finally
now only destroys groups we created locally (cp_comm_group / sub_groups
populated in the else-branch), leaving pool-owned groups alone for reuse.

cyanguwa's review feedback on PR #2993.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Flatten try/finally wrap in run_dpa_with_cp

The Round-1 P1 NCCL-communicator-leak fix (e162a9e) wrapped the
~540-line body of run_dpa_with_cp in try/finally. The wrap itself
was tiny but it re-indented every line of the body by one level,
inflating the PR diff of run_attention_with_cp.py to ~1000 lines
against origin/main.

Items 2+3 (d15bfce) since made the wrap unnecessary:
  - In pool mode, cp_comm_group and cp_comm_sub_groups are owned by
    the pool worker (which destroys them once at pool shutdown).
    run_dpa_with_cp neither creates nor destroys them, so an
    in-body exception can't leak communicators.
  - In single-shot mode, groups are still created locally, but the
    subprocess exits at function return; NCCL releases everything
    at process teardown, so a stray exception leaks communicators
    only for the milliseconds before the process dies — a bounded
    one-off cost, not the unbounded accumulation that Round-1
    flagged for pool mode.

Removing the wrap drops the run_attention_with_cp.py diff against
origin/main from ~1000 lines to ~120 lines without changing
observable behaviour. Smoke-tested: 4 representative cases pass.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Set test_essential=True to match shipping default

Round-3 review (greptile, discussion_r3250016711) flagged that the
working tree had test_essential=False — i.e. the full ~328-config
matrix instead of the ~38-config essential subset that the rest of the
CI matrix expects. Flipping back to True so CI doesn't regress baseline
on the known H1-style cascade configs that only appear in the full
matrix.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Retry once on pool-infrastructure failures with stderr-logged flake trace

The pool worker subprocess can die mid-case due to async NCCL aborts or
flaky 4-GPU collective state that doesn't reproduce on a fresh pool.
Without retry, these manifest as one-off CI failures attributable to
infrastructure, not the PR's content.

Add a single-attempt retry around PoolWorker.submit() that fires only
on infrastructure failure modes (pool-worker-died, timeout,
broken-pipe-pre-send). Test-assertion failures from the worker
(resp["error"]) carry full per-rank tracebacks and propagate without
retry — so a real bug still surfaces as FAILED.

Visibility: every retry attempt writes a [POOL-RETRY] line to stderr.
pytest captures per-test stderr and writes it into JUnit
<testcase>/<system-err>. A flaky test will appear as PASSED in the
case row but with a [POOL-RETRY] line in <system-err> — visible to
the reviewer, and queryable by CI dashboards looking for flake
patterns (e.g. "same test_id retries across multiple CI runs").

If both attempts die, a [POOL-RETRY-FAIL] line is also logged with
the first error's headline, then the second attempt's full traceback
propagates as the test failure.

Smoke-tested: 3 representative cases (p2p, a2a flash; p2p fused)
still PASS in 19 s.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [PyTorch] Pool: redirect non-rank-0 stdout to /dev/null; drop sentinel

Replaces the [CP_POOL_RESP] sentinel-prefix protocol with a stronger
fix at the source: on rank>0, close stdout at the fd level via dup2
to /dev/null at worker startup. Catches both Python `print` writes
and C-level (NCCL, libc, etc.) writes that the sentinel could only
mitigate by scanning + skipping non-protocol lines.

With non-rank-0 stdout silenced, rank 0's JSON line is the only
thing that reaches the parent's pipe, so PoolWorker._submit_once
collapses from a sentinel-scanning while loop to a single
select + readline + json.loads.

Closes follow-up M2 from the PR description; addresses greptile's
review comment on stdout pollution. Validated on 8xH100 with the
test_essential=True flash-attn pool path (9 passed / 55 skipped /
0 failed in 56s; no JSONDecodeError, no protocol corruption).

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Address PR review (R3): backend-cache, pool isolation, group-kill, decode-safety

- Invalidate DotProductAttention._attention_backends between pool cases so
  per-case NVTE_FLASH_ATTN/NVTE_FUSED_ATTN toggles take effect instead of
  reusing the previous case's resolved backend.
- torch.cuda.empty_cache() after each case so a 2-GPU pool doesn't squat on
  GPUs that an overlapping 4-GPU pool needs.
- PoolWorker subprocess uses start_new_session=True; _kill() uses killpg on
  the whole process group so torchrun's rank workers don't survive as
  orphans holding CUDA/NCCL state.
- On a failed worker response, kill the pool before raising so half-aborted
  CUDA/NCCL/FP8 state from a failed case doesn't leak into the next.
- Guard json.loads with a try/except + diagnostic so any rank-0 stdout
  pollution surfaces as a clear test failure rather than a silent protocol
  desync.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants